/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <cuda_runtime_api.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>

#include "./common.h"
#include "cute/tensor.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/layout/layout.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"

#ifdef ENABLE_FP4
#include <cuda_fp4.h>
#endif

namespace tensorrt_llm::kernels::cutlass_kernels {

// Argument bundle for one grouped (per-expert) MoE GEMM launch.
//   AType      - element type of the A operand
//   BType      - element type of the B operand
//   BScaleType - element type of the B scales / zero points
//   OType      - element type of the bias and of the output C
// Every field has a default so callers can aggregate-initialize only what they need.
template <typename AType, typename BType, typename BScaleType, typename OType>
struct GroupedGemmInput {
  // --- Operand / output buffers (all default to null) ---
  AType const* A{nullptr};
  // NOTE(review): presumably a per-expert cumulative row count — confirm against callers.
  int64_t const* total_tokens_including_expert{nullptr};
  BType const* B{nullptr};
  BScaleType const* scales{nullptr};
  BScaleType const* zeros{nullptr};
  OType const* biases{nullptr};
  OType* C{nullptr};
  float const** alpha_scales{nullptr};
  int* occupancy{nullptr};

  // --- Problem description ---
  ActivationType activation_type{ActivationType::InvalidType};
  int64_t num_rows{0};
  int64_t n{0};
  int64_t k{0};
  int num_experts{0};
  // Declared const: it can only be set at initialization (this also makes the struct
  // non-copy-assignable).
  int const groupwise_quant_group_size{0};

  // --- Behaviour flags ---
  bool bias_is_broadcast{true};
  bool use_fused_moe{false};

  // --- Launch configuration ---
  cudaStream_t stream{nullptr};
  cutlass_extensions::CutlassGemmConfig gemm_config;
};

// Argument bundle for the TMA warp-specialized grouped GEMM path.
// It carries:
//   * per-group pointer and stride arrays for the activation, weight, C and D operands;
//   * optional epilogue-fusion parameters (activation / gated activation / finalize);
//   * block-scaling metadata for MXFPX / NVFP4 scale factors;
//   * group-wise (INT4 / WFP4A16) quantization parameters;
//   * the CUTLASS GEMM workspace.
// NOTE(review): the pointer/stride arrays appear to hold one entry per expert and to be placed by
// configureWorkspace() — confirm against its definition.
struct TmaWarpSpecializedGroupedGemmInput {
  // Maps a CUTLASS layout tag to its transpose (RowMajor <-> ColumnMajor).
  template <class Tag>
  using TransposeLayoutTag =
      std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
                         cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;

  static_assert(
      std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
  static_assert(
      std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);

  // These are always the layouts of the A & B matrices. Activations and weights are assigned to
  // either A or B depending on swap_ab.
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;

  // When using Swap A&B we need to transpose the output matrix
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = cutlass::layout::RowMajor;
  using LayoutC_T = TransposeLayoutTag<LayoutC>;
  using LayoutD_T = TransposeLayoutTag<LayoutD>;

  // Grouped (pointer-tag) stride types with the pointer stripped, i.e. the per-group stride type.
  using StrideA = std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
  using StrideB = std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;

  using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
  using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
  using StrideC_T = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC_T*>>;
  using StrideD_T = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD_T*>>;

  // Number of elements covered by a single block scale factor.
  constexpr static int NVFP4BlockScaleVectorSize = 16;
  constexpr static int MXFPXBlockScaleVectorSize = 32;

  using NVFP4BlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<NVFP4BlockScaleVectorSize>;
  using MXFPXBlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<MXFPXBlockScaleVectorSize>;

  // 128
  // This is the outer-dimension alignment of the weight matrix that the fully padded SF refers to.
  // We require the SFs to be aligned to this value (zero padded as needed).
  // The weights do not need to be aligned to this value; CUTLASS handles the extra padding.
  // N here is shorthand for the outer dimension of the GEMM, and this applies to both the M and N
  // dimensions of the GEMM.
  constexpr static int MinNDimAlignmentNVFP4 = cute::size<0>(NVFP4BlockScaledConfig::SfAtom{});
  constexpr static int MinNDimAlignmentMXFPX = cute::size<0>(MXFPXBlockScaledConfig::SfAtom{});

  // Block scale vector size * 4
  // This is the inner-dimension (K) alignment of the weight matrix that the fully padded SF refers
  // to. We should never actually need to pad a buffer to this alignment.
  // The weights only need to be aligned to BlockScaleVectorSize; CUTLASS handles extra padding.
  // The SFs only need to be aligned to 4 (zero padded as needed).
  // K here is shorthand for the inner dimension of the GEMM.
  constexpr static int MinKDimAlignmentNVFP4 = cute::size<1>(NVFP4BlockScaledConfig::SfAtom{});
  constexpr static int MinKDimAlignmentMXFPX = cute::size<1>(MXFPXBlockScaledConfig::SfAtom{});

  // Rounds `dim` up to the next multiple of `alignment`.
  // Preconditions (not checked): `alignment` > 0 (it is used as a divisor) and `dim` >= 0
  // (integer division truncates toward zero, so negative inputs are not rounded "up").
  constexpr static int64_t alignToSfDim(int64_t dim, int64_t alignment) {
    return (dim + alignment - 1) / alignment * alignment;
  }

#ifdef ENABLE_FP8
  template <class T>
  constexpr static bool IsFP8_v =
      std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
  template <class T>
  constexpr static bool IsFP8_v = false;
#endif

  // GEMM output element type for activation type T: FP8 activations (only recognised when
  // ENABLE_FP8 is defined) produce nv_bfloat16 outputs; every other T maps to T itself.
  template <class T>
  using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, nv_bfloat16, T>;

  using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;

  // When true, activations and weights swap the A/B operand roles (see LayoutA/LayoutB above).
  bool swap_ab = false;
  ProblemShape shape_info{};
  // Type-erased per-group stride arrays.
  void* stride_act = nullptr;
  void* stride_weight = nullptr;

  // Per-group operand pointer arrays.
  void const** ptr_act = nullptr;
  void const** ptr_weight = nullptr;

  // C is currently the same in both epilogues
  void* stride_c = nullptr;
  void const** ptr_c = nullptr;

  // D is used in all cases except fused finalize
  void* stride_d = nullptr;
  void** ptr_d = nullptr;

  // Parameters for the FINALIZE epilogue fusion (see setFinalizeFusionParams).
  struct FusedFinalizeEpilogue {
    // Non-pointer tags: these are the plain (non-grouped) stride types of the final output.
    using StrideFinalOutput_T = cutlass::detail::TagToStrideC_t<LayoutD_T>;
    using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>;

    void* ptr_final_output = nullptr;
    StrideFinalOutput_T stride_final_output_transposed{};
    StrideFinalOutput stride_final_output{};

    void const** ptr_bias = nullptr;
    float const** ptr_router_scales = nullptr;

    int const** ptr_source_token_index = nullptr;
    int num_rows_in_final_output = 0;
    // -1 means "no override" (default); other meanings are defined by the consumer of this field.
    int shape_override = -1;

    bool use_reduction = true;
  };

  FusedFinalizeEpilogue fused_finalize_epilogue;

  // Which epilogue fusion (if any) the GEMM should apply.
  enum class EpilogueFusion { NONE, ACTIVATION, GATED_ACTIVATION, FINALIZE };
  EpilogueFusion fusion = EpilogueFusion::NONE;

  float const** alpha_scale_ptr_array = nullptr;

  // Block scale-factor storage type shared by the MXFPX and NVFP4 recipes.
  using ElementSF = uint8_t;
  using MXFPXElementSF = ElementSF;  // Just an alias for now
  using NVFP4ElementSF = ElementSF;  // Just an alias for now
  ElementSF const** fpX_block_scaling_factors_act = nullptr;
  ElementSF const** fpX_block_scaling_factors_weight = nullptr;

  void* fpX_block_scaling_factors_stride_act = nullptr;
  void* fpX_block_scaling_factors_stride_weight = nullptr;

  enum class FpXBlockScalingType { MXFPX, NVFP4, NONE };
  FpXBlockScalingType fpX_block_scaling_type = FpXBlockScalingType::NONE;

  // Group-wise quantization parameters (INT4 weights, or FP4 weights with 16-bit activations).
  // Only the activation-side scale (SFA) pointers/strides are used; the *_b and zero-point
  // members are marked unused.
  struct INT4GroupwiseParams {
    constexpr static int int4_group_size = 128;
    constexpr static int wfp4a16_group_size = 32;
    bool enabled = false;
    bool use_wfp4a16 = false;
    using SFA = __nv_bfloat16;
    using SFB = __nv_bfloat16;  // Unused
    using ProblemShapeInt = cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
    using LayoutSFA = typename cutlass::layout::ColumnMajor;
    using LayoutSFB = typename cutlass::layout::ColumnMajor;  // Unused
    using StrideSFA = cute::Stride<cute::Int<1>, int64_t, int64_t>;
    using StrideSFB = cute::Stride<cute::Int<1>, int64_t, int64_t>;  // Unused
    StrideSFA* stride_s_a = nullptr;
    StrideSFB* stride_s_b = nullptr;  // Unused
    const SFA** ptr_s_a = nullptr;
    const SFA** ptr_z_a = nullptr;  // Unused
    const SFB** ptr_s_b = nullptr;  // Unused
    const SFB** ptr_z_b = nullptr;  // Unused
    ProblemShapeInt shape{};
  };

  INT4GroupwiseParams int4_groupwise_params;

  // Scratch memory handed to the CUTLASS kernel.
  uint8_t* gemm_workspace = nullptr;
  size_t gemm_workspace_size = 0;

  // Whether to enable PDL (Programmatic Dependent Launch).
  bool enable_pdl{};

  // NOTE(review): presumably returns the byte size of each sub-buffer that configureWorkspace()
  // carves out of `start_ptr` — confirm against the definition.
  static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);

  static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type);

  void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace,
                          size_t gemm_workspace_size, FpXBlockScalingType scaling_type);

  // True once both the activation stride array and the activation pointer array are non-null.
  bool isValid() const { return stride_act != nullptr && ptr_act != nullptr; }

  void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens,
                               bool use_reduction);

  std::string toString() const;
};

// Returns true for the activation kinds treated as "gated" here: Swiglu, Geglu and SwigluBias.
// Every other ActivationType value yields false.
constexpr bool isGatedActivation(ActivationType activation_type) {
  switch (activation_type) {
    case ActivationType::Swiglu:
    case ActivationType::Geglu:
    case ActivationType::SwigluBias:
      return true;
    default:
      return false;
  }
}

template <typename T,                         /* Activation / scale / compute element type */
          typename WeightType,                /* Element type of the MoE expert weights */
          typename OutputType,                /* Element type written by the GEMM */
          typename ScaleBiasType = OutputType /* Element type of the scales and biases */
          >
class MoeGemmRunner {
 public:
  MoeGemmRunner();

  // ---------------------------------------------------------------------------------------------
  // Compile-time flags describing which quantization recipe the (T, WeightType) pair selects.
  // Each flag is forced to false when the feature macro it depends on is not defined.
  // ---------------------------------------------------------------------------------------------

  // FP4 weights with 16-bit activations (fp16 always; bf16 only when ENABLE_BF16 is defined).
#if defined(ENABLE_FP4) && defined(ENABLE_BF16)
  static constexpr bool use_wfp4a16 =
      std::is_same_v<WeightType, __nv_fp4_e2m1> &&
      (std::is_same_v<T, __nv_bfloat16> || std::is_same_v<T, half>);
#elif defined(ENABLE_FP4)
  static constexpr bool use_wfp4a16 =
      std::is_same_v<T, half> && std::is_same_v<WeightType, __nv_fp4_e2m1>;
#else
  static constexpr bool use_wfp4a16 = false;
#endif

  // FP8 (e4m3 or e5m2) activations whose weights are neither uint4b_t nor (with ENABLE_FP4) fp4.
#if defined(ENABLE_FP8) && defined(ENABLE_FP4)
  static constexpr bool use_fp8 =
      (std::is_same_v<T, __nv_fp8_e5m2> || std::is_same_v<T, __nv_fp8_e4m3>) &&
      !(std::is_same_v<WeightType, cutlass::uint4b_t> ||
        std::is_same_v<WeightType, __nv_fp4_e2m1>);
#elif defined(ENABLE_FP8)
  static constexpr bool use_fp8 =
      (std::is_same_v<T, __nv_fp8_e5m2> || std::is_same_v<T, __nv_fp8_e4m3>) &&
      !std::is_same_v<WeightType, cutlass::uint4b_t>;
#else
  static constexpr bool use_fp8 = false;
#endif

  // uint4b_t weights with FP8 e4m3 activations.
#if defined(ENABLE_FP8)
  static constexpr bool use_w4afp8 =
      std::is_same_v<WeightType, cutlass::uint4b_t> && std::is_same_v<T, __nv_fp8_e4m3>;
#else
  static constexpr bool use_w4afp8 = false;
#endif

  // Either of the 4-bit group-wise recipes above.
  static constexpr bool use_w4_groupwise = use_wfp4a16 || use_w4afp8;

  // use_fp4: FP4 activations. use_wfp4afp8: FP4 weights with FP8 e4m3 activations.
#if defined(ENABLE_FP4)
  static constexpr bool use_fp4 = std::is_same_v<T, __nv_fp4_e2m1>;
  static constexpr bool use_wfp4afp8 =
      std::is_same_v<WeightType, __nv_fp4_e2m1> && std::is_same_v<T, __nv_fp8_e4m3>;
#else
  static constexpr bool use_fp4 = false;
  static constexpr bool use_wfp4afp8 = false;
#endif

  // ---------------------------------------------------------------------------------------------
  // GEMM entry points.
  // NOTE(review): moeGemmBiasAct presumably applies bias + activation in the epilogue while
  // moeGemm does not — confirm against the definitions.
  // ---------------------------------------------------------------------------------------------
  void moeGemmBiasAct(GroupedGemmInput<T, WeightType, ScaleBiasType, OutputType> inputs,
                      TmaWarpSpecializedGroupedGemmInput hopper_inputs);

  void moeGemm(GroupedGemmInput<T, WeightType, ScaleBiasType, OutputType> inputs,
               TmaWarpSpecializedGroupedGemmInput hopper_inputs);

  // ---------------------------------------------------------------------------------------------
  // Candidate kernel configurations (the non-static overload presumably uses this runner's SM).
  // ---------------------------------------------------------------------------------------------
  std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(
      bool supports_finalize_fusion) const;
  static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(
      int sm, bool supports_finalize_fusion);
  static std::vector<cutlass_extensions::CutlassGemmConfig> getTmaWarpSpecializedConfigs(
      int sm, bool supports_finalize_fusion);
  static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);

  // ---------------------------------------------------------------------------------------------
  // Capability queries.
  // ---------------------------------------------------------------------------------------------
  [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const;

  [[nodiscard]] static bool supportsTmaWarpSpecialized(int sm);

  // Convenience overload: forwards this runner's stored SM to the static query.
  [[nodiscard]] bool supportsTmaWarpSpecialized() const { return supportsTmaWarpSpecialized(sm_); }

  [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config,
                                            ActivationType activation_type, int gemm_n,
                                            int gemm_k) const;
  [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n,
                                                  int gemm_k) const;

  size_t getMaxWorkspaceSize(int num_experts) const;

  [[nodiscard]] int getSM() const;

 private:
  template <typename EpilogueTag>
  void dispatchToArch(GroupedGemmInput<T, WeightType, ScaleBiasType, OutputType> inputs,
                      TmaWarpSpecializedGroupedGemmInput hopper_inputs);

  template <typename EpilogueTag>
  void runGemm(GroupedGemmInput<T, WeightType, ScaleBiasType, OutputType> inputs,
               TmaWarpSpecializedGroupedGemmInput hopper_inputs);

  size_t calcMaxWorkspaceSize(int num_experts) const;

  // Device properties captured by the runner.
  int sm_{};
  int multi_processor_count_{};
  // Mutable so const queries may update them.
  // NOTE(review): presumably a workspace-size cache keyed on num_experts — confirm.
  mutable int num_experts_ = 0;
  mutable size_t gemm_workspace_size_ = 0;
};

}  // namespace tensorrt_llm::kernels::cutlass_kernels
